{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import numpy.typing as npt\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "from moment.utils.config import Config\n",
    "from moment.utils.utils import parse_config\n",
    "from moment.data.generate_synthetic_data import SyntheticDataset\n",
    "from moment.models.base import BaseModel\n",
    "from moment.models.moment import MOMENT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D projection applied to the embeddings of the synthetic-data experiments.\n",
    "DIMENSION_REDUCTION_METHOD = 'tsne' # supported: 'tsne', 'pca'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed_timeseries_in_manifold(model: torch.nn.Module, \n",
    "                                 y: npt.NDArray, \n",
    "                                 device: torch.device, \n",
    "                                 input_mask: npt.NDArray = None,\n",
    "                                 dimension_reduction_method: str = 'tsne'):\n",
    "    y = y.to(device)\n",
    "    n_samples, _, seq_len = y.shape\n",
    "    model = model.to(device)\n",
    "    \n",
    "    if input_mask is None:\n",
    "        input_mask = torch.ones((n_samples, seq_len)).to(device)\n",
    "   \n",
    "    model.eval()\n",
    "    embeddings_manifold = []\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        outputs = model.embed(x_enc=y, input_mask=input_mask, reduction='mean')\n",
    "    embeddings = outputs.embeddings.detach().cpu().numpy()\n",
    "\n",
    "    if dimension_reduction_method == 'tsne':\n",
    "        embeddings_manifold = TSNE(n_components=2, n_jobs=5).fit_transform(embeddings)\n",
    "    elif dimension_reduction_method == 'pca':\n",
    "        embeddings_manifold = PCA(n_components=2).fit_transform(embeddings)\n",
    "\n",
    "    # Move tensors and models back to CPU\n",
    "    y = y.detach().cpu().numpy()\n",
    "    model = model.cpu()\n",
    "    input_mask = input_mask.detach().cpu().numpy()\n",
    "\n",
    "    return embeddings, embeddings_manifold\n",
    "\n",
    "def save_experiment_artifacts(filename : str, \n",
    "                              embeddings : npt.NDArray, \n",
    "                              y : npt.NDArray, \n",
    "                              c : npt.NDArray, \n",
    "                              embeddings_manifold : npt.NDArray):\n",
    "    # Save the data and embeddings\n",
    "    # \"../../assets/results/interpretability/frequency_artifacts.npz\"\n",
    "    with open(filename, \"wb\") as f:\n",
    "        np.savez(f, embeddings=embeddings, y=y, c=c, \n",
    "                 embeddings_manifold=embeddings_manifold)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defaults"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the default YAML configuration.\n",
    "DEFAULT_CONFIG_PATH = \"../../configs/default.yaml\"\n",
    "# CUDA device index; only used when CUDA is available.\n",
    "GPU_ID = 0\n",
    "# Name of the training run whose pre-trained weights are loaded.\n",
    "run_name = \"fearless-planet-52\" "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parse config, build model and load pre-trained weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative (disabled): load one specific checkpoint file by hand.\n",
    "# with open('/home/extra_scratch/mgoswami/moment_checkpoints/avid-moon-55/MOMENT_checkpoint_5000.pth', 'rb') as f:\n",
    "#     checkpoint = torch.load(f)\n",
    "checkpoint = BaseModel.load_pretrained_weights(run_name=run_name, opt_steps=None)\n",
    "\n",
    "# The default config file is passed both as the experiment config and as the defaults.\n",
    "config = Config(config_file_path=DEFAULT_CONFIG_PATH, default_config_file_path=DEFAULT_CONFIG_PATH).parse()\n",
    "# Fall back to the CPU when CUDA is not available.\n",
    "config['device'] = GPU_ID if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "args = parse_config(config)\n",
    "model = MOMENT(configs=args)\n",
    "# The weights are stored under the \"model_state_dict\" key of the checkpoint.\n",
    "model.load_state_dict(checkpoint[\"model_state_dict\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Frequency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic sinusoids generated by `gen_sinusoids_with_varying_freq`\n",
    "synthetic_dataset = SyntheticDataset(\n",
    "    n_samples=1024,\n",
    "    freq=1,\n",
    "    freq_range=(1, 32),\n",
    "    noise_mean=0.,\n",
    "    noise_std=0.1,\n",
    "    random_seed=13,\n",
    ")\n",
    "y, c = synthetic_dataset.gen_sinusoids_with_varying_freq()\n",
    "\n",
    "n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot five samples, picked at a regular stride, next to each other\n",
    "sample_indices = np.arange(0, n_samples+1, n_samples//4-1)\n",
    "x_ticks = np.arange(0, seq_len+1, 128)\n",
    "y_ticks = np.arange(-1.5, 1.5, 0.5)\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(30, 6), sharey=True)\n",
    "for panel, idx in enumerate(sample_indices):\n",
    "    ax = axs[panel]\n",
    "    ax.plot(y[idx].squeeze().numpy())\n",
    "    ax.set_xticks(ticks=x_ticks, labels=x_ticks, fontdict={\"fontsize\" : 16})\n",
    "    frequency = c[:, 0][idx].squeeze().numpy()\n",
    "    ax.set_title(f\"Frequency: {frequency:.2f}\", fontsize=16)\n",
    "# The y axis is shared, so only the first panel gets explicit ticks\n",
    "axs[0].set_yticks(ticks=y_ticks, labels=y_ticks, fontdict={\"fontsize\" : 16})\n",
    "plt.savefig(\"../../assets/figures/interpretability/frequency_timeseries.pdf\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the series and project the embeddings with the configured method\n",
    "embeddings, embeddings_manifold = embed_timeseries_in_manifold(\n",
    "    model=model,\n",
    "    y=y,\n",
    "    device=args.device,\n",
    "    dimension_reduction_method=DIMENSION_REDUCTION_METHOD,\n",
    ")\n",
    "\n",
    "# Optional: persist the artifacts\n",
    "# save_experiment_artifacts(filename=f\"../../assets/results/interpretability/frequency_artifacts_{DIMENSION_REDUCTION_METHOD}.npz\", \n",
    "#     embeddings=embeddings, y=y, c=c, embeddings_manifold=embeddings_manifold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw string: the LaTeX backslashes are not Python escape sequences\n",
    "plt.title(r\"$y = \\sin(2c \\pi x) + \\epsilon$\", fontsize=20)\n",
    "# Each point is one series, coloured by its value of c\n",
    "plt.scatter(embeddings_manifold[:, 0], \n",
    "            embeddings_manifold[:, 1], c=c[:, 0].squeeze().numpy(), cmap='magma')\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "# Discrete colorbar with unit-wide bands over the frequency range\n",
    "plt.colorbar(boundaries=np.arange(\n",
    "    synthetic_dataset.freq_range[0], synthetic_dataset.freq_range[1]+1, 1))\n",
    "plt.savefig(f\"../../assets/figures/interpretability/frequency_artifacts_{DIMENSION_REDUCTION_METHOD}.pdf\", \n",
    "    bbox_inches='tight') \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Amplitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic sinusoids generated by `gen_sinusoids_with_varying_amplitude`\n",
    "synthetic_dataset = SyntheticDataset(\n",
    "    n_samples=2048,\n",
    "    seq_len=512,\n",
    "    freq=16,\n",
    "    amplitude_range=(1/4, 4),\n",
    "    noise_mean=0.,\n",
    "    noise_std=0.1,\n",
    "    random_seed=13,\n",
    ")\n",
    "y, c = synthetic_dataset.gen_sinusoids_with_varying_amplitude()\n",
    "\n",
    "n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot five samples, picked at a regular stride, next to each other\n",
    "sample_indices = np.arange(0, n_samples+1, n_samples//4-1)\n",
    "x_ticks = np.arange(0, seq_len+1, 128)\n",
    "y_ticks = np.arange(-16, 16, 10)\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(30, 6), sharey=True)\n",
    "for panel, idx in enumerate(sample_indices):\n",
    "    ax = axs[panel]\n",
    "    ax.plot(y[idx].squeeze().numpy())\n",
    "    ax.set_xticks(ticks=x_ticks, labels=x_ticks, fontdict={\"fontsize\" : 16})\n",
    "    amplitude = c[:, 0][idx].squeeze().numpy()\n",
    "    ax.set_title(f\"Amplitude: {amplitude:.2f}\", fontsize=16)\n",
    "# The y axis is shared, so only the first panel gets explicit ticks\n",
    "axs[0].set_yticks(ticks=y_ticks, labels=y_ticks, fontdict={\"fontsize\" : 16})\n",
    "plt.savefig(\"../../assets/figures/interpretability/amplitude_timeseries.pdf\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the series and project the embeddings with the configured method\n",
    "embeddings, embeddings_manifold = embed_timeseries_in_manifold(\n",
    "    model=model,\n",
    "    y=y,\n",
    "    device=args.device,\n",
    "    dimension_reduction_method=DIMENSION_REDUCTION_METHOD,\n",
    ")\n",
    "\n",
    "# Optional: persist the artifacts\n",
    "# save_experiment_artifacts(filename=f\"../../assets/results/interpretability/amplitude_artifacts_{DIMENSION_REDUCTION_METHOD}.npz\", \n",
    "#     embeddings=embeddings, y=y, c=c, embeddings_manifold=embeddings_manifold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw string: the LaTeX backslashes are not Python escape sequences\n",
    "plt.title(r\"$y = c*\\sin(32\\pi x) + \\epsilon$\", fontsize=20)\n",
    "# Each point is one series, coloured by its value of c\n",
    "plt.scatter(embeddings_manifold[:, 0], \n",
    "            embeddings_manifold[:, 1], c=c[:, 0].squeeze().numpy(), cmap='magma')\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "# Discrete colorbar with unit-wide bands starting at the lower end of the amplitude range\n",
    "plt.colorbar(boundaries=np.arange(synthetic_dataset.amplitude_range[0], synthetic_dataset.amplitude_range[1]+1, 1))\n",
    "plt.savefig(f\"../../assets/figures/interpretability/amplitude_artifacts_{DIMENSION_REDUCTION_METHOD}.pdf\", \n",
    "    bbox_inches='tight') \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Trend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic sinusoids generated by `gen_sinusoids_with_varying_trend`\n",
    "synthetic_dataset = SyntheticDataset(\n",
    "    n_samples=2048,\n",
    "    freq=16,\n",
    "    trend_range=(1/8, 8),\n",
    "    noise_mean=0.,\n",
    "    noise_std=0.1,\n",
    "    random_seed=13,\n",
    ")\n",
    "y, c = synthetic_dataset.gen_sinusoids_with_varying_trend()\n",
    "\n",
    "# Trend curves t**c, used by the plots of this section\n",
    "_, t = synthetic_dataset._generate_x()\n",
    "trend = t**c\n",
    "\n",
    "n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the trends\n",
    "for i in range(0, len(trend), 32):\n",
    "    plt.plot(trend[i].squeeze().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the data: five samples, each drawn together with its trend curve\n",
    "fig, axs = plt.subplots(1, 5, figsize=(30, 6), sharey=True)\n",
    "for i, idx in enumerate(np.arange(0, n_samples+1, n_samples//4-1)):\n",
    "    axs[i].plot(y[idx].squeeze().numpy())\n",
    "    axs[i].plot(trend[idx].squeeze().numpy())\n",
    "    axs[i].set_xticks(\n",
    "        ticks=np.arange(0, seq_len+1, 128), \n",
    "        labels=np.arange(0, seq_len+1, 128), \n",
    "        fontdict={\"fontsize\" : 16})\n",
    "    axs[i].set_title(\"Trend: {:.2f}\".format(c[:, 0][idx].squeeze().numpy(), ), fontsize=16)\n",
    "# Keep matplotlib's automatic y ticks and only enlarge their labels; a fixed\n",
    "# np.arange(-2, 2, 10) grid collapses to a single tick at -2.\n",
    "axs[0].tick_params(axis='y', labelsize=16)\n",
    "plt.savefig(\"../../assets/figures/interpretability/trend_timeseries.pdf\", bbox_inches='tight') \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the series and project the embeddings with the configured method\n",
    "embeddings, embeddings_manifold = embed_timeseries_in_manifold(\n",
    "    model=model,\n",
    "    y=y,\n",
    "    device=args.device,\n",
    "    dimension_reduction_method=DIMENSION_REDUCTION_METHOD,\n",
    ")\n",
    "\n",
    "# Optional: persist the artifacts\n",
    "# save_experiment_artifacts(filename=f\"../../assets/results/interpretability/trend_artifacts_{DIMENSION_REDUCTION_METHOD}.npz\", \n",
    "#     embeddings=embeddings, y=y, c=c, embeddings_manifold=embeddings_manifold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw string: the LaTeX backslashes are not Python escape sequences\n",
    "plt.title(r\"$y = x^c + \\sin(32\\pi x) + \\epsilon$\", fontsize=20)\n",
    "# Each point is one series, coloured by its value of c\n",
    "plt.scatter(embeddings_manifold[:, 0], \n",
    "            embeddings_manifold[:, 1], c=c[:, 0].squeeze().numpy(), cmap='magma')\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "# Discrete colorbar with unit-wide bands starting at the lower end of the trend range\n",
    "plt.colorbar(boundaries=np.arange(synthetic_dataset.trend_range[0], synthetic_dataset.trend_range[1]+1, 1))\n",
    "plt.savefig(f\"../../assets/figures/interpretability/trend_artifacts_{DIMENSION_REDUCTION_METHOD}.pdf\", bbox_inches='tight') \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Baseline shift"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic sinusoids generated by `gen_sinusoids_with_varying_baseline`\n",
    "synthetic_dataset = SyntheticDataset(\n",
    "    n_samples=2048,\n",
    "    freq=16,\n",
    "    baseline_range=(-2, 2),\n",
    "    noise_mean=0.,\n",
    "    noise_std=0.1,\n",
    "    random_seed=13,\n",
    ")\n",
    "y, c = synthetic_dataset.gen_sinusoids_with_varying_baseline()\n",
    "\n",
    "n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the data: five samples, each drawn together with its entry of c\n",
    "fig, axs = plt.subplots(1, 5, figsize=(30, 6), sharey=True)\n",
    "for i, idx in enumerate(np.arange(0, n_samples+1, n_samples//4-1)):\n",
    "    axs[i].plot(y[idx].squeeze().numpy())\n",
    "    axs[i].plot(c[idx].squeeze().numpy())\n",
    "    axs[i].set_xticks(\n",
    "        ticks=np.arange(0, seq_len+1, 128), \n",
    "        labels=np.arange(0, seq_len+1, 128), \n",
    "        fontdict={\"fontsize\" : 16})\n",
    "    axs[i].set_title(\"Baseline: {:.2f}\".format(c[:, 0][idx].squeeze().numpy(), ), fontsize=16)\n",
    "# Keep matplotlib's automatic y ticks and only enlarge their labels; a fixed\n",
    "# np.arange(-2, 2, 10) grid collapses to a single tick at -2.\n",
    "axs[0].tick_params(axis='y', labelsize=16)\n",
    "plt.savefig(\"../../assets/figures/interpretability/baseline_timeseries.pdf\", bbox_inches='tight') \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the baseline-shift series and project the embeddings with the configured method.\n",
    "embeddings, embeddings_manifold = embed_timeseries_in_manifold(\n",
    "    model=model, y=y, device=args.device, dimension_reduction_method=DIMENSION_REDUCTION_METHOD)\n",
    "\n",
    "# Optional: persist the artifacts (file name uses `baseline_`, so the trend artifacts are not overwritten).\n",
    "# save_experiment_artifacts(filename=f\"../../assets/results/interpretability/baseline_artifacts_{DIMENSION_REDUCTION_METHOD}.npz\", \n",
    "#     embeddings=embeddings, y=y, c=c, embeddings_manifold=embeddings_manifold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw string: the LaTeX backslashes are not Python escape sequences\n",
    "plt.title(r\"$y = c + \\sin(32\\pi x) + \\epsilon$\", fontsize=20)\n",
    "# Each point is one series, coloured by its value of c\n",
    "plt.scatter(embeddings_manifold[:, 0], \n",
    "            embeddings_manifold[:, 1], c=c[:, 0].squeeze().numpy(), cmap='magma')\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "# Discrete colorbar with unit-wide bands over the baseline range\n",
    "plt.colorbar(boundaries=np.arange(synthetic_dataset.baseline_range[0], synthetic_dataset.baseline_range[1]+1, 1))\n",
    "plt.savefig(f\"../../assets/figures/interpretability/baseline_artifacts_{DIMENSION_REDUCTION_METHOD}.pdf\", bbox_inches='tight') \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Auto-correlation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One dataset per base frequency, all built with the same settings and seed;\n",
    "# the per-frequency batches are concatenated in this order.\n",
    "wave_frequencies = (1, 2, 3, 5)\n",
    "\n",
    "series_per_frequency, c_per_frequency = [], []\n",
    "for wave_frequency in wave_frequencies:\n",
    "    synthetic_dataset = SyntheticDataset(n_samples=512, freq=wave_frequency, baseline_range=(-2, 2), \n",
    "                                         noise_mean=0., noise_std=0.1, random_seed=13)\n",
    "    y_group, c_group = synthetic_dataset.gen_sinusoids_with_varying_correlation()\n",
    "    series_per_frequency.append(y_group)\n",
    "    c_per_frequency.append(c_group)\n",
    "\n",
    "# `synthetic_dataset` is the last dataset built; all of them share n_samples and seq_len\n",
    "n_samples = len(wave_frequencies)*synthetic_dataset.n_samples\n",
    "seq_len = synthetic_dataset.seq_len\n",
    "\n",
    "y = torch.cat(series_per_frequency, dim=0)\n",
    "c = torch.cat(c_per_frequency, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the data: five samples picked at a regular stride\n",
    "fig, axs = plt.subplots(1, 5, figsize=(30, 6), sharey=True)\n",
    "for i, idx in enumerate(np.arange(0, n_samples+1, n_samples//4-1)):\n",
    "    axs[i].plot(y[idx].squeeze().numpy())\n",
    "    axs[i].set_xticks(\n",
    "        ticks=np.arange(0, seq_len+1, 128), \n",
    "        labels=np.arange(0, seq_len+1, 128), \n",
    "        fontdict={\"fontsize\" : 16})\n",
    "    axs[i].set_title(\"Offset: {:.2f}\".format(c[:, 0][idx].squeeze().numpy(), ), fontsize=16)\n",
    "# Keep matplotlib's automatic y ticks and only enlarge their labels; a fixed\n",
    "# np.arange(-2, 2, 10) grid collapses to a single tick at -2.\n",
    "axs[0].tick_params(axis='y', labelsize=16)\n",
    "plt.savefig(\"../../assets/figures/interpretability/correlation_timeseries.pdf\", bbox_inches='tight') \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the series and project the embeddings with the configured method\n",
    "embeddings, embeddings_manifold = embed_timeseries_in_manifold(\n",
    "    model=model,\n",
    "    y=y,\n",
    "    device=args.device,\n",
    "    dimension_reduction_method=DIMENSION_REDUCTION_METHOD,\n",
    ")\n",
    "\n",
    "# Optional: persist the artifacts\n",
    "# save_experiment_artifacts(filename=f\"../../assets/results/interpretability/correlation_artifacts_{DIMENSION_REDUCTION_METHOD}.npz\", \n",
    "#     embeddings=embeddings, y=y, c=c, embeddings_manifold=embeddings_manifold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw string: the LaTeX backslashes are not Python escape sequences\n",
    "plt.title(r\"$y = \\sin(2\\pi f x + c) + \\epsilon, \\ c \\in [0, 2\\pi]$\", fontsize=20)\n",
    "\n",
    "# Rows of the concatenated data that belong to each base frequency\n",
    "wave_groups = {\n",
    "    \"1\" : slice(0, 512),\n",
    "    \"2\" : slice(512, 1024),\n",
    "    \"3\" : slice(1024, 1536),\n",
    "    \"5\" : slice(1536, 2048)\n",
    "}\n",
    "group_markers = {\"1\" : 'o', \"2\" : 'x', \"3\" : '^', \"5\" : '*'}\n",
    "\n",
    "# Hand-placed group labels; the coordinates only make sense for the matching projection\n",
    "group_label_positions = {\n",
    "    'tsne' : {\"1\" : (18, -30), \"2\" : (18, 20), \"3\" : (-30, 25), \"5\" : (-30, -22)},\n",
    "    'pca' : {\"1\" : (0.2, -0.2), \"2\" : (0.3, 0.2), \"3\" : (-0.5, 0.3), \"5\" : (-0.5, -0.2)},\n",
    "}\n",
    "\n",
    "# One colour scale for all groups, so the single colorbar is valid for every marker\n",
    "point_colors = c[:, 0].squeeze().numpy()\n",
    "color_min, color_max = point_colors.min(), point_colors.max()\n",
    "for group, rows in wave_groups.items():\n",
    "    plt.scatter(embeddings_manifold[rows, 0], \n",
    "                embeddings_manifold[rows, 1], \n",
    "                c=point_colors[rows], cmap='magma', marker=group_markers[group],\n",
    "                vmin=color_min, vmax=color_max)\n",
    "\n",
    "# No labels are drawn for a projection without hand-placed coordinates\n",
    "for group, (label_x, label_y) in group_label_positions.get(DIMENSION_REDUCTION_METHOD, {}).items():\n",
    "    plt.text(label_x, label_y, f\"$f={group}$\", fontsize=16, color='darkred')\n",
    "\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.colorbar(boundaries=np.arange(0, 2*np.pi+1, 1))\n",
    "# plt.savefig(f\"../../assets/figures/interpretability/autocorrelation_artifacts_{DIMENSION_REDUCTION_METHOD}.pdf\", \n",
    "#     bbox_inches='tight') \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import stats\n",
    "\n",
    "\n",
    "# Probability plot of the model's mask-embedding values against a normal distribution\n",
    "# with sparams=(0, 1).\n",
    "mask_embedding = model.patch_embedding.mask_embedding.data.detach().cpu().numpy()\n",
    "# With fit=True, probplot also returns the least-squares fit (slope, intercept, r);\n",
    "# plot=plt draws the plot on the current pyplot figure.\n",
    "_, (slope, intercept, r) = stats.probplot(\n",
    "    x=mask_embedding, sparams=(0, 1), dist=stats.norm, fit=True, rvalue=False, plot=plt)\n",
    "# Other candidate distributions: dist = [stats.logistic, stats.norm, stats.t, stats.cauchy]\n",
    "\n",
    "# Restyle the axes and title that probplot created.\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.xlabel(\"Theoretical quantiles\", fontsize=16)\n",
    "plt.ylabel(\"Observed Values\", fontsize=16)\n",
    "plt.title(\"Probability Plot of Mask Embeddings\", fontsize=20)\n",
    "# NOTE(review): the annotation's x-coordinate is the R^2 value itself, which looks\n",
    "# accidental -- confirm the intended position.\n",
    "plt.text(r**2, -0.1, f\"$R^2 = {r**2:.4f}$\", fontsize=16)\n",
    "plt.savefig(\"../../assets/figures/interpretability/mask_embeddings.pdf\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Kolmogorov-Smirnov Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-sided Kolmogorov-Smirnov test of the mask-embedding values against scipy's\n",
    "# \"norm\" distribution; no distribution parameters are passed.\n",
    "test_result = stats.kstest(rvs=mask_embedding, cdf=\"norm\", alternative='two-sided')\n",
    "print(test_result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Input embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Output embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm, trange\n",
    "from moment.utils.short_univariate_classification_datasets import short_univariate_classification_datasets\n",
    "from moment.data.dataloader import get_timeseries_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_test_dataloader(args):\n",
    "    \"\"\"Return the test-split dataloader for `args.full_file_path_and_name`.\n",
    "\n",
    "    Sets `args.dataset_names` and `args.data_split` on the passed object, i.e. the\n",
    "    caller's `args` is modified in place.\n",
    "    \"\"\"\n",
    "    args.data_split = 'test'\n",
    "    args.dataset_names = args.full_file_path_and_name\n",
    "    return get_timeseries_dataloader(args=args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the default YAML configuration.\n",
    "DEFAULT_CONFIG_PATH = \"../../configs/default.yaml\"\n",
    "# CUDA device index; only used when CUDA is available.\n",
    "GPU_ID = 0\n",
    "# Classification datasets whose test splits are embedded and plotted.\n",
    "dataset_names = ['Crop', 'ElectricDevices', 'Wafer', 'ECG5000', 'ChlorineConcentration']\n",
    "# This replaces the `config` / `args` objects with the representation-learning configuration.\n",
    "config = Config(config_file_path=\"../../configs/classification/unsupervised_representation_learning.yaml\", \n",
    "                default_config_file_path=DEFAULT_CONFIG_PATH).parse()\n",
    "# Fall back to the CPU when CUDA is not available.\n",
    "config['device'] = GPU_ID if torch.cuda.is_available() else 'cpu'\n",
    "args = parse_config(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Projections drawn for every dataset, in this order\n",
    "projection_factories = {\n",
    "    'tsne' : lambda: TSNE(n_components=2, n_jobs=10),\n",
    "    'pca' : lambda: PCA(n_components=2),\n",
    "}\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "    args.full_file_path_and_name = f'/TimeseriesDatasets/classification/UCR/{dataset_name}/{dataset_name}_TEST.ts'\n",
    "    test_dataloader = get_test_dataloader(args)\n",
    "\n",
    "    # Collect embeddings and labels batch by batch\n",
    "    embedding_batches, label_batches = [], []\n",
    "    for batch_x in tqdm(test_dataloader):\n",
    "        timeseries = batch_x.timeseries.float().to(args.device)\n",
    "        input_mask = batch_x.input_mask.long().to(args.device)\n",
    "\n",
    "        # 'none' matches neither projection, so only the embeddings are used here\n",
    "        _embeddings, _ = embed_timeseries_in_manifold(\n",
    "            model=model, y=timeseries, device=args.device, input_mask=input_mask,\n",
    "            dimension_reduction_method='none')\n",
    "\n",
    "        embedding_batches.append(_embeddings)\n",
    "        label_batches.append(batch_x.labels)\n",
    "\n",
    "    embeddings = np.concatenate(embedding_batches, axis=0)\n",
    "    labels = np.concatenate(label_batches, axis=0).squeeze()\n",
    "\n",
    "    # One figure per projection, points coloured by class label\n",
    "    for dimension_reduction_method, make_projection in projection_factories.items():\n",
    "        embeddings_manifold = make_projection().fit_transform(embeddings)\n",
    "\n",
    "        plt.title(f\"{dataset_name}\", fontsize=20)\n",
    "        plt.scatter(embeddings_manifold[:, 0],\n",
    "                    embeddings_manifold[:, 1], c=labels)\n",
    "        plt.xticks(fontsize=16)\n",
    "        plt.yticks(fontsize=16)\n",
    "        # plt.colorbar(boundaries=np.arange(labels.min(), labels.max()+1, 1))\n",
    "        # Hide ticks, tick labels and the frame\n",
    "        plt.tick_params(axis='both', which='both', bottom=False, top=False,\n",
    "                        labelbottom=False, right=False, left=False, labelleft=False)\n",
    "        plt.gca().set_frame_on(False)\n",
    "        plt.savefig(f\"../../assets/figures/interpretability/{dataset_name}_{dimension_reduction_method}.pdf\", bbox_inches='tight')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Frequency analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reconstruct_timeseries(model: torch.nn.Module, \n",
    "                           y: npt.NDArray, \n",
    "                           device: torch.device, \n",
    "                           input_mask: npt.NDArray = None):\n",
    "    y = y.to(device)\n",
    "    n_samples, _, seq_len = y.shape\n",
    "    model = model.to(device)\n",
    "    \n",
    "    if input_mask is None:\n",
    "        input_mask = torch.ones((n_samples, seq_len)).to(device)\n",
    "   \n",
    "    model.eval()\n",
    "    embeddings_manifold = []\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        outputs = model.reconstruct(x_enc=y, input_mask=input_mask)\n",
    "    reconstruction = outputs.reconstruction.detach().cpu().numpy()\n",
    "\n",
    "    # Move tensors and models back to CPU\n",
    "    y = y.detach().cpu().numpy()\n",
    "    model = model.cpu()\n",
    "    input_mask = input_mask.detach().cpu().numpy()\n",
    "\n",
    "    return y, reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic sinusoids generated by `gen_sinusoids_with_varying_freq`, here with freq_range up to 96\n",
    "synthetic_dataset = SyntheticDataset(\n",
    "    n_samples=1024,\n",
    "    freq=1,\n",
    "    freq_range=(1, 96),\n",
    "    noise_mean=0.,\n",
    "    noise_std=0.1,\n",
    "    random_seed=13,\n",
    ")\n",
    "y, c = synthetic_dataset.gen_sinusoids_with_varying_freq()\n",
    "\n",
    "n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs and model reconstructions, both as NumPy arrays\n",
    "timeseries, reconstruction = reconstruct_timeseries(\n",
    "    model=model,\n",
    "    y=y,\n",
    "    device=args.device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare one series with its reconstruction\n",
    "# idx = np.random.randint(0, n_samples)\n",
    "idx = 512\n",
    "# Raw f-string: the LaTeX backslashes are not Python escape sequences\n",
    "plt.title(rf\"$y = \\sin(2*{c[idx][0]:.1f} \\pi x) + \\epsilon$\", fontsize=20)\n",
    "plt.plot(timeseries[idx].squeeze(), color='darkblue', label='True')\n",
    "plt.plot(reconstruction[idx].squeeze(), color='red', linestyle='dashed', label='Reconstruction')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean squared reconstruction error per sample, plotted against c\n",
    "error = np.mean((timeseries - reconstruction)**2, axis=-1).squeeze()\n",
    "plt.plot(c[:, 0].squeeze().numpy(), error, c='darkblue')\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "# Raw string: the LaTeX backslashes are not Python escape sequences\n",
    "plt.xlabel(r\"$y = \\sin(2*c \\pi x) + \\epsilon$\", fontsize=16)\n",
    "plt.ylabel(\"MSE\", fontsize=16)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "# Autocorrelation plot of the per-sample reconstruction error, taken in the order of `error`.\n",
    "pd.plotting.autocorrelation_plot(error)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
